import random
from pathlib import Path
from pprint import pprint  # pprint = pretty print, see: https://docs.python.org/3/library/pprint.html

# Import PyTorch and matplotlib
import matplotlib.pyplot as plt
import torch
from torch import nn as nn
from torch import nn # nn contains all of PyTorch's building blocks for neural networks

# Check PyTorch version
# print(torch.__version__)

# Device-agnostic setup: use CUDA when PyTorch reports it as available, otherwise the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

def plot_predictions(train_data,
                     train_labels,
                     test_data,
                     test_labels,
                     predictions=None):
    """
    Scatter-plot the training split, the test split and, optionally, predictions.

    A new 10x7 figure is created. Training points are drawn in blue, test
    points in green and, when `predictions` is not None, the predictions are
    drawn in red against `test_data` (predictions are made on the test data).
    A legend is added at the end. The figure is neither shown nor returned.
    """
    plt.figure(figsize=(10, 7))

    # Each entry: (x values, y values, colour, legend label), drawn in this order
    series = [
        (train_data, train_labels, "b", "Training data"),
        (test_data, test_labels, "g", "Testing data"),
    ]
    if predictions is not None:
        series.append((test_data, predictions, "r", "Predictions"))

    for x_values, y_values, colour, legend_label in series:
        plt.scatter(x_values, y_values, c=colour, s=4, label=legend_label)

    # Show the legend
    plt.legend(prop={"size": 14})


# Build the model by subclassing nn.Module
class LinearRegressionModelV2(nn.Module):
    """Linear regression model made of a single ``nn.Linear`` layer.

    The layer takes one input feature and produces one output feature.
    """

    def __init__(self) -> None:
        super().__init__()
        # One feature in, one feature out
        self.linear_layer = nn.Linear(1, 1)

    # Define how the forward pass is computed
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the output of the linear layer applied to ``x``."""
        prediction = self.linear_layer(x)
        return prediction

# Known parameters of the straight line used to generate the labels
weight = 0.7
bias = 0.3

# Bounds and spacing of the input values
start = 0
end = 1
step = 0.02

# Features X as a single column and labels y on the known line.
# The column shape matters: without the extra dimension, shape errors
# happen later on inside the linear layer.
X = torch.arange(start, end, step).reshape(-1, 1)
y = X * weight + bias

# First 80% of the samples for training, the remainder for testing
train_split = int(len(X) * 0.8)
X_train, X_test = X[:train_split], X[train_split:]
y_train, y_test = y[:train_split], y[train_split:]



# Optional: fix the random seed with manual_seed before the model is created
torch.manual_seed(42)
model_1 = LinearRegressionModelV2()

# Print the parameter device, move the model to the device chosen earlier,
# then print the parameter device again
print(next(model_1.parameters()).device)
model_1 = model_1.to(device)
print(next(model_1.parameters()).device)

# Loss function: L1 loss
loss_fn = nn.L1Loss()
# Optimizer: SGD over the newly created model's parameters
optimizer = torch.optim.SGD(model_1.parameters(), lr=0.01)

torch.manual_seed(42)

# Number of passes over the training data
epochs = 100

# Put all data on the same device as the model;
# without this an error is raised (not all model/data on one device)
X_train, X_test, y_train, y_test = (
    tensor.to(device) for tensor in (X_train, X_test, y_train, y_test)
)

for epoch in range(epochs):
    ### Training
    model_1.train()  # back to train mode; eval mode is set further down in every epoch

    # Clear the gradients of the previous step before new ones are computed
    optimizer.zero_grad()
    # Forward pass and loss on the training split
    y_pred = model_1(X_train)
    loss = loss_fn(y_pred, y_train)
    # Backpropagate, then let the optimizer update the parameters
    loss.backward()
    optimizer.step()

    ### Testing
    model_1.eval()  # evaluation mode for testing (inference)
    with torch.no_grad():  # gradients are not needed to compute the test loss
        test_pred = model_1(X_test)
        test_loss = loss_fn(test_pred, y_test)

    # Report progress on every 10th epoch
    if epoch % 10 == 0:
        print(f"Epoch: {epoch} | Train loss: {loss} | Test loss: {test_loss}")

# Inspect the model's parameters: print the current state_dict, then the
# weight and bias values the data was generated with, for comparison.
# (The first message announces the model's current `weights` and `bias`,
# the second one the original values.)
print("现在模型的内部参数（ `weights` 和 `bias` ）：")
pprint(model_1.state_dict())
print("\n原始模型的内部参数（ `weights` 和 `bias` ）")
print(f"weights: {weight}, bias: {bias}")

# 1. Put the model in evaluation mode
model_1.eval()
# 2. Run the forward pass under inference mode
with torch.inference_mode():
    # 3. All objects must be on the same device. model_1 and X_test were both
    #    moved to `device` earlier in this script; if in doubt, call
    #    .to(device) on them again before this line.
    y_preds = model_1(X_test)

# plot_predictions has no defaults for the four data arguments, so they must be
# passed explicitly (passing only `predictions` raises a TypeError).
# Everything is moved to the CPU first: plotting won't work with data that is
# not on the CPU.
# NOTE(review): nothing displays the figure here; when running as a plain script,
# consider plt.show() or plt.savefig(...) after this call.
plot_predictions(train_data=X_train.cpu(),
                 train_labels=y_train.cpu(),
                 test_data=X_test.cpu(),
                 test_labels=y_test.cpu(),
                 predictions=y_preds.cpu())

# Directory, file name and full path for the saved parameters
MODEL_PATH = Path("models")
MODEL_NAME = "01_pytorch_workflow_model_1.pth"
MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME

# Create the models directory (with missing parents); an existing one is accepted
MODEL_PATH.mkdir(parents=True, exist_ok=True)

# Save the model's state_dict, which holds only its learned parameters
print(f"Saving model to: {MODEL_SAVE_PATH}")
torch.save(model_1.state_dict(), MODEL_SAVE_PATH)


# if __name__ == '__main__':
#     print("hello world")